Variational Gradient Matching for Dynamical Systems: Dynamic Causal Modeling


Authors:

Nico Stephan Gorbach and Stefan Bauer, email: nico.gorbach@gmail.com

Contents:

Instructional code for the NIPS (2017) paper Scalable Variational Inference for Dynamical Systems by Nico S. Gorbach, Stefan Bauer and Joachim M. Buhmann. Please cite our paper if you use our program in a publication. The derivations in this document are also given in the doctoral thesis https://www.research-collection.ethz.ch/handle/20.500.11850/261734 as well as in parts of Wenk et al. (2018).
Example dynamical system used in this code: Dynamic Causal Modeling (visual attention system) with three hidden neuronal states and 12 hidden hemodynamic states. The system is driven by given external inputs, and the states are only indirectly observed through the BOLD signal change equation.

User Input: Simulation Settings

Input the type of ODEs used to generate the data as a string. Options: 'nonlinear_forward_modulation_by_attention', 'forward_modulation_and_driven_by_attention', 'forward_modulation_by_attention', 'backward_modulation_by_attention', 'backward_modulation_and_driven_by_attention', 'absent_modulation', 'absent_attention_input', 'absent_photic_input', 'driven_by_attention', 'photic_input'.
simulation.odes = 'forward_modulation_and_driven_by_attention';
Input a cell array containing the symbols (characters) of the observed states as they appear in the '_ODEs.txt' file. E.g., to observe deoxyhemoglobin content, blood volume and blood flow, set simulation.observed_states = {'q_1','q_3','q_2','v_1','v_3','v_2','f_1','f_3','f_2'}.
simulation.observed_states = {};
Input a positive real number (final time of the simulation):
simulation.final_time = 359*3.22;
Input a function handle (observation variance of the states as a function of the state trajectories):
simulation.state_obs_variance = @(x)(repmat(bsxfun(@rdivide,var(x),5),size(x,1),1));
Input a positive real number (time interval between observations):
simulation.interval_between_observations = 0.1;
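
To make the observation-variance handle above concrete, here is a small sketch that applies it to a made-up state matrix. It assumes that rows index time points and columns index states; the toy values and the way noise is drawn in the last line are illustrative assumptions, not the repository's simulation code.

```matlab
% Toy state matrix: 4 time points (rows) x 2 states (columns); values are made up.
x = [0 1; 2 1; 4 3; 6 3];

% Same handle as simulation.state_obs_variance above.
state_obs_variance = @(x)(repmat(bsxfun(@rdivide,var(x),5),size(x,1),1));

% One variance per state (a fifth of the state's sample variance),
% repeated for every time point.
obs_var = state_obs_variance(x);
disp(size(obs_var))      % same size as x
disp(obs_var(1,:))       % equals var(x)/5

% Illustrative only: Gaussian observation noise with these variances.
y = x + sqrt(obs_var) .* randn(size(x));
```

The handle therefore ties the noise level of each state to that state's own spread, so that all states are observed at a comparable signal-to-noise ratio.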

User Input: Estimation Settings

Input the type of ODEs used for estimation as a string. Options: 'nonlinear_forward_modulation_by_attention', 'forward_modulation_and_driven_by_attention', 'forward_modulation_by_attention', 'backward_modulation_by_attention', 'backward_modulation_and_driven_by_attention', 'absent_modulation', 'absent_attention_input', 'absent_photic_input', 'driven_by_attention', 'photic_input'.
candidate_odes = 'forward_modulation_and_driven_by_attention';
Input a row vector of positive real numbers of size 1 x 2 (kernel parameters):
kernel.param = [10,0.2];
Input a row vector of positive real numbers of size 1 x number of ODEs:
state.derivative_variance = 6.*ones(11-3,1);
Input a row vector of positive real numbers in ascending order (time points used for estimation):
time.est = 0:3.22:359*3.22;

Preliminary operations
close all; clc; addpath('VGM_functions')

Preprocessing for candidate ODEs

[symbols,ode,plot_settings,state,simulation,odes_path,coupling_idx,opt_settings] = ...
preprocessing_dynamic_causal_modeling (simulation,candidate_odes,state);
ODEs:
d q_1/dt == - #3 - 25*exp(-q_1)*exp(f_1)*((3/5)^exp(-f_1) - 1)/16
d q_3/dt == - #1 - 25*exp(-q_3)*exp(f_3)*((3/5)^exp(-f_3) - 1)/16
d q_2/dt == - #2 - 25*exp(-q_2)*exp(f_2)*((3/5)^exp(-f_2) - 1)/16
d v_1/dt == 5*exp(-v_1)*exp(f_1)/8 - #3
d v_3/dt == 5*exp(-v_3)*exp(f_3)/8 - #1
d v_2/dt == 5*exp(-v_2)*exp(f_2)/8 - #2
d f_1/dt == s_1*exp(-f_1)
d f_3/dt == s_3*exp(-f_3)
d f_2/dt == s_2*exp(-f_2)
d s_1/dt == n_1 - 3*s_1/5 - 8*exp(f_1)/25 + 8/25
d s_3/dt == n_3 - 3*s_3/5 - 8*exp(f_3)/25 + 8/25
d s_2/dt == n_2 - 3*s_2/5 - 8*exp(f_2)/25 + 8/25
d n_1/dt == a_11*n_1 + a_12*n_2 + c_11*u_1
d n_3/dt == a_32*n_2 + a_33*n_3 + c_33*u_3
d n_2/dt == a_22*n_2 + a_23*n_3 + n_1*(a_21 + b_212*u_2 + b_213*u_3)
where
#1 == 5*exp(17*v_3/8)/8
#2 == 5*exp(17*v_2/8)/8
#3 == 5*exp(17*v_1/8)/8

Simulate Trajectories

[symbols_true,ode_true] = preprocessing_dynamic_causal_modeling (simulation,simulation.odes,state);
ODEs: identical to the system printed in the preprocessing step for the candidate ODEs, since simulation.odes and candidate_odes are both set to 'forward_modulation_and_driven_by_attention'.
Sample ODE parameters that lead to non-diverging trajectories:
non_diverging_trajectories = false; i = 0;
while ~non_diverging_trajectories
Non-self-inhibitory neuronal couplings (sampled uniformly in the interval [-0.8, 0.8]):
simulation.ode_param = -0.8 + (0.8-(-0.8)) * rand(1,length(symbols_true.param));
% simulation.ode_param = [0.46,0.13,0.39,0.26,0.5,0.26,0.1,1.25,-1,-1,-1]; % published ODE parameters (slightly modified from Stephan et al., 2008)
self-inhibitory neuronal couplings set to -1:
simulation.ode_param(end-2:end) = -1;
try
simulation_old = simulation;
[simulation,obs_to_state_relation,fig_handle,plot_handle] = simulate_state_dynamics_dcm(...
simulation,symbols_true,ode_true,time,plot_settings,state.ext_input,'plot');
non_diverging_trajectories = 1;
end
end

Mass Action Dynamical Systems

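A deterministic dynamical system is represented by a set of $K$ ordinary differential equations (ODEs) with model parameters $\boldsymbol\theta \in \mathbb{R}^d$ that describe the evolution of $K$ states $\mathbf{x}(t) = [x_1(t),\ldots,x_K(t)]^T$ such that:
$\dot{\mathbf{x}}(t) = \frac{d\mathbf{x}(t)}{dt} = \mathbf{f}(\mathbf{x}(t),\boldsymbol\theta) \qquad (1)$.
A sequence of observations, $\mathbf{y}(t)$, is usually contaminated by measurement error, which we assume to be normally distributed with zero mean and variance $\sigma_k^2$ for each of the $K$ states, i.e. $\mathbf{E} \sim \mathcal{N}(\mathbf{E};\mathbf{0},\mathbf{D})$, with $\mathbf{D}_{ik} = \sigma_k^2 \delta_{ik}$. For $N$ distinct time points the overall system may therefore be summarized as
$\mathbf{Y} = \mathbf{X} + \mathbf{E}$,
where
$\mathbf{X} = [\mathbf{x}(t_1),\ldots,\mathbf{x}(t_N)] = [\mathbf{x}_1,\ldots,\mathbf{x}_K]^T$,
$\mathbf{Y} = [\mathbf{y}(t_1),\ldots,\mathbf{y}(t_N)] = [\mathbf{y}_1,\ldots,\mathbf{y}_K]^T$,
and $\mathbf{x}_k = [x_k(t_1),\ldots,x_k(t_N)]^T$ is the $k$'th state sequence and $\mathbf{y}_k = [y_k(t_1),\ldots,y_k(t_N)]^T$ are the corresponding observations. Given the observations $\mathbf{Y}$ and the description of the dynamical system (1), the aim is to estimate both the state variables $\mathbf{X}$ and the parameters $\boldsymbol\theta$.
We consider only dynamical systems that are locally linear with respect to the ODE parameters $\boldsymbol\theta$ and the individual states $\mathbf{x}_u$. Such ODEs include mass-action kinetics and are given by:
$f_k(\mathbf{x}(t),\boldsymbol\theta) = \sum_{i=1} \theta_{ki} \prod_{j \in \mathcal{M}_{ki}} x_j \qquad (2)$,
with $\mathcal{M}_{ki} \subseteq \{1,\ldots,K\}$ describing the state variables in each factor of the equation (i.e. the functions are linear in the parameters and contain arbitrarily large products of monomials of the states).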
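
As a small illustration of the local linearity described above, the following sketch evaluates a two-state Lotka-Volterra system, which has the mass-action form. The choice of system, the parameter values and the state values are assumptions made for illustration; they are not one of the DCM models used in this script.

```matlab
% Lotka-Volterra as a mass-action system (illustrative, not a DCM model):
%   dx_1/dt =  theta_1*x_1 - theta_2*x_1*x_2
%   dx_2/dt = -theta_3*x_2 + theta_4*x_1*x_2
theta = [2; 1; 4; 1];                         % hypothetical parameter values
f = @(x,theta) [ theta(1)*x(1) - theta(2)*x(1)*x(2); ...
                -theta(3)*x(2) + theta(4)*x(1)*x(2)];

x  = [3; 5];
dx = f(x,theta);                              % [2*3 - 1*15; -4*5 + 1*15] = [-9; -5]

% Linear in the parameters: f(x,a*theta+b*theta2) = a*f(x,theta) + b*f(x,theta2)
theta2 = [0.5; 0.3; 1; 2];
lhs = f(x, 2*theta + 3*theta2);
rhs = 2*f(x,theta) + 3*f(x,theta2);
disp(norm(lhs - rhs))                         % zero up to round-off

% Affine in the individual state x_1 when x_2 is held fixed:
% the second difference along x_1 vanishes.
g = @(x1) f([x1; x(2)],theta);
disp(norm(g(1) - 2*g(2.5) + g(4)))            % zero up to round-off
```

The system as a whole is nonlinear because of the product x_1*x_2, yet each conditional view (parameters given states, or one state given the others) is linear, which is what makes the conditional distributions derived later Gaussian.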

start timer
tic;

Prior on States and State Derivatives

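Gradient matching with Gaussian processes assumes a joint Gaussian process prior on states and their derivatives:
$\begin{pmatrix} \mathbf{X} \\ \dot{\mathbf{X}} \end{pmatrix} \sim \mathcal{N}\left( \begin{pmatrix} \mathbf{X} \\ \dot{\mathbf{X}} \end{pmatrix}; \begin{pmatrix} \mathbf{0} \\ \mathbf{0} \end{pmatrix}, \begin{pmatrix} \mathbf{C}_{\phi} & \mathbf{C}_{\phi}' \\ {'\mathbf{C}_{\phi}} & \mathbf{C}_{\phi}'' \end{pmatrix} \right) \qquad (3)$,
with
$\mathrm{cov}(x_k(t),x_k(t')) = C_{\phi_k}(t,t')$,
$\mathrm{cov}(\dot{x}_k(t),x_k(t')) = \frac{\partial C_{\phi_k}(t,t')}{\partial t} =: {'C_{\phi_k}}(t,t')$,
$\mathrm{cov}(x_k(t),\dot{x}_k(t')) = \frac{\partial C_{\phi_k}(t,t')}{\partial t'} =: C_{\phi_k}'(t,t')$,
$\mathrm{cov}(\dot{x}_k(t),\dot{x}_k(t')) = \frac{\partial^2 C_{\phi_k}(t,t')}{\partial t \, \partial t'} =: C_{\phi_k}''(t,t')$.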

Matching Gradients

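Given the joint distribution over states and their derivatives (3) as well as the ODEs (2), we have two expressions for the state derivatives:
$\dot{\mathbf{X}} = \mathbf{F} + \boldsymbol\epsilon_1, \quad \boldsymbol\epsilon_1 \sim \mathcal{N}(\boldsymbol\epsilon_1;\mathbf{0},\mathbf{I}\gamma)$,
$\dot{\mathbf{X}} = {'\mathbf{C}_{\phi}} \mathbf{C}_{\phi}^{-1} \mathbf{X} + \boldsymbol\epsilon_2, \quad \boldsymbol\epsilon_2 \sim \mathcal{N}(\boldsymbol\epsilon_2;\mathbf{0},\mathbf{A})$,
where $\mathbf{F} := \mathbf{f}(\mathbf{X},\boldsymbol\theta)$, $\mathbf{A} := \mathbf{C}_{\phi}'' - {'\mathbf{C}_{\phi}} \mathbf{C}_{\phi}^{-1} \mathbf{C}_{\phi}'$ and $\gamma$ is the error variance in the ODEs. Note that, in a deterministic system, the output of the ODEs $\mathbf{F}$ should equal the state derivatives $\dot{\mathbf{X}}$. However, in the first equation above we relax this constraint by adding stochasticity to the state derivatives in order to compensate for a potential model mismatch. The second equation above is obtained by deriving the conditional distribution for $\dot{\mathbf{X}}$ from the joint distribution in equation (3). Equating the two expressions above, we can eliminate the unknown state derivatives $\dot{\mathbf{X}}$:
$\mathbf{F} = {'\mathbf{C}_{\phi}} \mathbf{C}_{\phi}^{-1} \mathbf{X} + \boldsymbol\epsilon_0 \qquad (4)$,
with $\boldsymbol\epsilon_0 := \boldsymbol\epsilon_2 - \boldsymbol\epsilon_1$.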
[dC_times_invC,inv_C,A_plus_gamma_inv] = kernel_function(kernel,state,time.est);
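
The quantities returned by the call above can be sketched for a single state as follows. This is a minimal sketch under stated assumptions: it uses an RBF kernel with hypothetical hyperparameters and a hypothetical ODE error variance, and the kernel implemented in kernel_function may differ. The variable names mirror the outputs above only by analogy.

```matlab
t      = (0:1:8)';            % estimation time points (column vector)
N      = numel(t);
sigma2 = 1;  ell = 1;         % hypothetical RBF variance and length scale
gamma  = 0.1;                 % hypothetical ODE error variance

d = bsxfun(@minus,t,t');      % d(i,j) = t_i - t_j
E = exp(-d.^2/(2*ell^2));

C   = sigma2*E;                            % cov(x(t_i),  x(t_j))
dC  = -sigma2*(d/ell^2).*E;                % cov(dx(t_i), x(t_j)): derivative w.r.t. the first argument
ddC = sigma2*(1/ell^2 - d.^2/ell^4).*E;    % cov(dx(t_i), dx(t_j))

inv_C            = inv(C);
dC_times_invC    = dC / C;                     % maps a state sequence to its expected derivative
A                = ddC - dC_times_invC*dC';    % conditional covariance of the derivatives
A_plus_gamma_inv = inv(A + gamma*eye(N));

% Rough sanity check: the GP derivative of sin(t) compared with cos(t) at an
% interior point (agreement is only approximate for such a coarse grid).
dx_gp = dC_times_invC*sin(t);
disp([dx_gp(5), cos(t(5))])
```

Adding gamma to the diagonal keeps A + gamma*I well conditioned even when A itself is close to singular, which is one practical reason for relaxing the deterministic ODE constraint.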

Rewrite ODEs as Linear Combination in Parameters

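Since, according to the mass-action dynamics (equation 2), the ODEs are linear in the parameters $\boldsymbol\theta$, we can rewrite the ODEs in equation (2) as a linear combination in the parameters:
$\mathbf{B}_{\boldsymbol\theta k} \boldsymbol\theta + \mathbf{b}_{\boldsymbol\theta k} \stackrel{!}{=} \mathbf{f}_k(\mathbf{X},\boldsymbol\theta) \qquad (5)$,
where the matrices $\mathbf{B}_{\boldsymbol\theta k}$ and $\mathbf{b}_{\boldsymbol\theta k}$ are defined such that the ODEs $\mathbf{f}_k(\mathbf{X},\boldsymbol\theta)$ are expressed as a linear combination in $\boldsymbol\theta$.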
[ode_param.lin_comb.B,ode_param.lin_comb.b] = rewrite_odes_as_linear_combination_in_parameters(ode,symbols);
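
To show what such a rewrite looks like, the sketch below builds the matrices by hand for a two-state Lotka-Volterra system. The system, the state values and the parameter values are illustrative assumptions; in the script the matrices are generated symbolically by rewrite_odes_as_linear_combination_in_parameters.

```matlab
% States sampled at N time points (made-up values); columns are x_1 and x_2.
X = [3 5; 2 4; 1 2; 4 1];
N = size(X,1);
x1 = X(:,1);  x2 = X(:,2);
theta = [2; 1; 4; 1];                 % hypothetical parameters

% f_1 =  theta_1*x_1 - theta_2*x_1*x_2
% f_2 = -theta_3*x_2 + theta_4*x_1*x_2
B1 = [x1, -x1.*x2, zeros(N,2)];   b1 = zeros(N,1);
B2 = [zeros(N,2), -x2, x1.*x2];   b2 = zeros(N,1);

% Direct evaluation of the ODEs for comparison.
f1 =  theta(1)*x1 - theta(2)*x1.*x2;
f2 = -theta(3)*x2 + theta(4)*x1.*x2;

disp(norm(B1*theta + b1 - f1))    % zero: B_1*theta + b_1 reproduces f_1
disp(norm(B2*theta + b2 - f2))    % zero: B_2*theta + b_2 reproduces f_2
```

The offsets are zero here because every term carries a parameter; they become nonzero for ODEs with parameter-free terms, such as the hemodynamic equations printed earlier.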

Posterior over ODE Parameters

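Inserting (5) into (4) and solving for $\boldsymbol\theta$ yields:
$\boldsymbol\theta = \mathbf{B}_{\boldsymbol\theta}^{+} \left( {'\mathbf{C}_{\phi}} \mathbf{C}_{\phi}^{-1} \mathbf{X} - \mathbf{b}_{\boldsymbol\theta} + \boldsymbol\epsilon_0 \right)$,
where $\mathbf{B}_{\boldsymbol\theta}^{+}$ denotes the pseudo-inverse of $\mathbf{B}_{\boldsymbol\theta}$. Since $\mathbf{C}_{\phi}$ is block diagonal we can rewrite the expression above as:
$\boldsymbol\theta = (\mathbf{B}_{\boldsymbol\theta}^T \mathbf{B}_{\boldsymbol\theta})^{-1} \sum_k \mathbf{B}_{\boldsymbol\theta k}^T \left( {'\mathbf{C}_{\phi_k}} \mathbf{C}_{\phi_k}^{-1} \mathbf{x}_k - \mathbf{b}_{\boldsymbol\theta k} + \boldsymbol\epsilon_0^{(k)} \right)$,
where we substitute the Moore-Penrose inverse for the pseudo-inverse (i.e. $\mathbf{B}_{\boldsymbol\theta}^{+} := (\mathbf{B}_{\boldsymbol\theta}^T \mathbf{B}_{\boldsymbol\theta})^{-1} \mathbf{B}_{\boldsymbol\theta}^T$). We can therefore derive the posterior distribution over ODE parameters:
$p(\boldsymbol\theta \mid \mathbf{X},\boldsymbol\phi,\gamma) = \mathcal{N}\left( \boldsymbol\theta; (\mathbf{B}_{\boldsymbol\theta}^T \mathbf{B}_{\boldsymbol\theta})^{-1} \sum_k \mathbf{B}_{\boldsymbol\theta k}^T \left( {'\mathbf{C}_{\phi_k}} \mathbf{C}_{\phi_k}^{-1} \mathbf{x}_k - \mathbf{b}_{\boldsymbol\theta k} \right), \; \mathbf{B}_{\boldsymbol\theta}^{+} (\mathbf{A} + \mathbf{I}\gamma) \mathbf{B}_{\boldsymbol\theta}^{+T} \right) \qquad (6)$.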
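
The posterior mean and covariance over the ODE parameters can be sketched numerically as follows. The example reuses a two-state Lotka-Volterra system as a stand-in, replaces the GP derivative estimates by the exact ODE output, and uses a hypothetical value for the derivative covariance; all of these are assumptions for illustration only.

```matlab
X = [3 5; 2 4; 1 2; 4 1];  N = size(X,1);   % made-up state values
x1 = X(:,1);  x2 = X(:,2);
theta_true = [2; 1; 4; 1];                  % hypothetical parameters

% ODEs written as a linear combination in the parameters.
B1 = [x1, -x1.*x2, zeros(N,2)];   b1 = zeros(N,1);
B2 = [zeros(N,2), -x2, x1.*x2];   b2 = zeros(N,1);

% Stand-in for the GP derivative estimates of each state:
% here simply the exact ODE output (no noise, no GP).
dx1 = B1*theta_true + b1;
dx2 = B2*theta_true + b2;

% Posterior mean: (B'B)^(-1) * sum_k B_k' * (dx_k - b_k)
B = [B1; B2];
theta_mean = (B'*B) \ (B1'*(dx1 - b1) + B2'*(dx2 - b2));

% Posterior covariance: B^+ (A + gamma*I) B^+'
A_plus_gamma = 0.1*eye(2*N);                % hypothetical A + gamma*I
B_pinv    = (B'*B) \ B';
theta_cov = B_pinv*A_plus_gamma*B_pinv';

disp(theta_mean')                           % recovers theta_true in this noise-free setting
```

With exact derivatives the mean reproduces the true parameters; in the script the derivative estimates come from the GP and the states from the current proxies, so the estimate is only as good as those inputs.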

Rewrite Hemodynamic ODEs as Linear Combination in (monotonic functions of) Individual Hemodynamic States

Rewrite the BOLD signal change equation as a linear combination in a monotonic function of the deoxyhemoglobin content $q$:
.
[state.deoxyhemo.R,state.deoxyhemo.r] = rewrite_bold_signal_eqn_as_linear_combination_in_deoxyhemo(symbols);
Rewrite the deoxyhemoglobin content ODE as a linear combination in a monotonic function of the blood volume $v$:
.
[state.vol.R,state.vol.r] = rewrite_deoxyhemo_ODE_as_linear_combination_in_vol(ode,symbols);
Rewrite the blood volume ODE as a linear combination in a monotonic function of the blood flow $f$:
.
[state.flow.R,state.flow.r] = rewrite_vol_ODE_as_linear_combination_in_flow(ode,symbols);

Rewrite the blood flow and vasosignalling ODEs as a linear combination in vasosignalling $s$:
.
.
[state.vaso.R,state.vaso.r] = rewrite_vaso_and_flow_odes_as_linear_combination_in_vaso(ode,symbols);

Rewrite Neuronal ODEs as Linear Combination in Individual Neuronal States

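We rewrite the ODE(s) $\mathbf{f}_k(\mathbf{X},\boldsymbol\theta)$ as a linear combination in the individual state $\mathbf{x}_u$:
$\mathbf{R}_{uk} \mathbf{x}_u + \mathbf{r}_{uk} \stackrel{!}{=} \mathbf{f}_k(\mathbf{X},\boldsymbol\theta)$,
where the matrices $\mathbf{R}_{uk}$ and $\mathbf{r}_{uk}$ are defined such that the ODE $\mathbf{f}_k(\mathbf{X},\boldsymbol\theta)$ is expressed as a linear combination in the individual state $\mathbf{x}_u$.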
[state.neuronal.R,state.neuronal.r] = rewrite_odes_as_linear_combination_in_ind_neuronal_states(ode,symbols,coupling_idx.states);

Posterior over Individual States

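Given the linear combination of the ODEs w.r.t. an individual state, we define the matrices $\mathbf{B}_u$ and $\mathbf{b}_u$ such that the expression $\mathbf{f}(\mathbf{X},\boldsymbol\theta) - {'\mathbf{C}_{\phi}} \mathbf{C}_{\phi}^{-1} \mathbf{X}$ is rewritten as a linear combination in an individual state $\mathbf{x}_u$:
$\mathbf{B}_u \mathbf{x}_u + \mathbf{b}_u \stackrel{!}{=} \mathbf{f}(\mathbf{X},\boldsymbol\theta) - {'\mathbf{C}_{\phi}} \mathbf{C}_{\phi}^{-1} \mathbf{X} \qquad (7)$.
Inserting (7) into (4) and solving for $\mathbf{x}_u$ yields:
$\mathbf{x}_u = \mathbf{B}_u^{+} \left( \boldsymbol\epsilon_0 - \mathbf{b}_u \right)$,
where $\mathbf{B}_u^{+}$ denotes the pseudo-inverse of $\mathbf{B}_u$. Since $\mathbf{C}_{\phi}$ is block diagonal we can rewrite the expression above as:
$\mathbf{x}_u = (\mathbf{B}_u^T \mathbf{B}_u)^{-1} \sum_k \mathbf{B}_{uk}^T \left( \boldsymbol\epsilon_0^{(k)} - \mathbf{b}_{uk} \right)$,
where we substitute the Moore-Penrose inverse for the pseudo-inverse (i.e. $\mathbf{B}_u^{+} := (\mathbf{B}_u^T \mathbf{B}_u)^{-1} \mathbf{B}_u^T$). We can therefore derive the posterior distribution over an individual state $\mathbf{x}_u$:
$p(\mathbf{x}_u \mid \boldsymbol\theta,\mathbf{X}_{-u},\boldsymbol\phi,\gamma) = \mathcal{N}\left( \mathbf{x}_u; -\mathbf{B}_u^{+} \mathbf{b}_u, \; \mathbf{B}_u^{+} (\mathbf{A} + \mathbf{I}\gamma) \mathbf{B}_u^{+T} \right) \qquad (8)$,
with $\mathbf{X}_{-u}$ denoting the set of all states except state $\mathbf{x}_u$.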

Mean-field Variational Inference

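To infer the parameters $\boldsymbol\theta$, we want to find the maximum a posteriori estimate (MAP):
$\boldsymbol\theta^{*} := \arg\max_{\boldsymbol\theta} \ln p(\boldsymbol\theta \mid \mathbf{Y},\boldsymbol\phi,\gamma,\boldsymbol\sigma) = \arg\max_{\boldsymbol\theta} \ln \int p(\boldsymbol\theta \mid \mathbf{X},\boldsymbol\phi,\gamma) \, p(\mathbf{X} \mid \mathbf{Y},\boldsymbol\phi,\boldsymbol\sigma) \, d\mathbf{X} \qquad (9)$.
However, the integral above is intractable due to the strong couplings induced by the nonlinear ODEs $\mathbf{f}$ which appear in the term $p(\boldsymbol\theta \mid \mathbf{X},\boldsymbol\phi,\gamma)$.
We use mean-field variational inference to establish variational lower bounds that are analytically tractable by decoupling the state variables from the ODE parameters as well as decoupling the state variables from each other. Note that, since the ODEs described by equation (2) are locally linear, both conditional distributions $p(\boldsymbol\theta \mid \mathbf{X},\mathbf{Y},\boldsymbol\phi,\gamma,\boldsymbol\sigma)$ (equation (6)) and $p(\mathbf{x}_u \mid \boldsymbol\theta,\mathbf{X}_{-u},\mathbf{Y},\boldsymbol\phi,\gamma,\boldsymbol\sigma)$ (equation (8)) are analytically tractable and Gaussian distributed, as mentioned previously. The decoupling is induced by designing a variational distribution $Q(\boldsymbol\theta,\mathbf{X})$ which is restricted to the family of factorial distributions:
$\mathcal{Q} := \left\{ Q : Q(\boldsymbol\theta,\mathbf{X}) = q(\boldsymbol\theta) \prod_u q(\mathbf{x}_u) \right\}$.
The particular forms of $q(\boldsymbol\theta)$ and $q(\mathbf{x}_u)$ are designed to be Gaussian, which places them in the same family as the true full conditional distributions. To find the optimal factorial distribution we minimize the Kullback-Leibler divergence between the variational and the true posterior distribution:
$\hat{Q} := \arg\min_{Q(\boldsymbol\theta,\mathbf{X}) \in \mathcal{Q}} \mathrm{KL}\left[ Q(\boldsymbol\theta,\mathbf{X}) \, \| \, p(\boldsymbol\theta,\mathbf{X} \mid \mathbf{Y},\boldsymbol\phi,\gamma,\boldsymbol\sigma) \right] \qquad (10)$,
where $\hat{Q}$ is the proxy distribution. The proxy distribution that minimizes the KL-divergence (10) depends on the true full conditionals and is given by:
$\hat{q}(\boldsymbol\theta) \propto \exp\left( E_{Q_{-\boldsymbol\theta}} \ln p(\boldsymbol\theta \mid \mathbf{X},\mathbf{Y},\boldsymbol\phi,\gamma,\boldsymbol\sigma) \right), \quad \hat{q}(\mathbf{x}_u) \propto \exp\left( E_{Q_{-u}} \ln p(\mathbf{x}_u \mid \boldsymbol\theta,\mathbf{X}_{-u},\mathbf{Y},\boldsymbol\phi,\gamma,\boldsymbol\sigma) \right) \qquad (11)$.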

Denoising BOLD Observations

We denoise the BOLD observation by standard GP regression.
bold_response.denoised_obs = denoising_BOLD_observations(simulation.bold_response{:,{'n_1','n_3','n_2'}},inv_C,symbols,simulation);

Fitting observations of state trajectories

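We fit the observations of the state trajectories by standard GP regression. The data-informed distribution $p(\mathbf{X} \mid \mathbf{Y},\boldsymbol\phi,\boldsymbol\sigma)$ in equation (9) can be determined analytically using Gaussian process regression with the GP prior $p(\mathbf{X} \mid \boldsymbol\phi) = \prod_k \mathcal{N}(\mathbf{x}_k;\mathbf{0},\mathbf{C}_{\phi_k})$:
$p(\mathbf{X} \mid \mathbf{Y},\boldsymbol\phi,\boldsymbol\sigma) = \prod_k \mathcal{N}(\mathbf{x}_k;\boldsymbol\mu_k(\mathbf{y}_k),\boldsymbol\Sigma_k)$,
where $\boldsymbol\mu_k(\mathbf{y}_k) := \sigma_k^{-2} \left( \sigma_k^{-2}\mathbf{I} + \mathbf{C}_{\phi_k}^{-1} \right)^{-1} \mathbf{y}_k$ and $\boldsymbol\Sigma_k^{-1} := \sigma_k^{-2}\mathbf{I} + \mathbf{C}_{\phi_k}^{-1}$.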
[mu,inv_sigma] = fitting_state_observations(inv_C,obs_to_state_relation,simulation,symbols);
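
The GP regression step can be sketched for a single, directly observed state as follows. The RBF kernel, its hyperparameters, the noise variance and the sine test signal are assumptions for illustration; fitting_state_observations additionally handles the mapping between observations and states, which is not reproduced here.

```matlab
t = (0:1:8)';  N = numel(t);
sigma2_k = 1;  ell = 1;                    % hypothetical RBF hyperparameters
d = bsxfun(@minus,t,t');
C = sigma2_k*exp(-d.^2/(2*ell^2));         % GP prior covariance of the state

obs_var = 0.05;                            % hypothetical observation noise variance
y = sin(t) + sqrt(obs_var)*randn(N,1);     % noisy observations of a toy trajectory

inv_C     = inv(C);
inv_sigma = eye(N)/obs_var + inv_C;        % posterior precision
mu        = inv_sigma \ (y/obs_var);       % posterior mean

% Equivalent textbook form of the posterior mean: C*(C + obs_var*I)^(-1)*y
mu_alt = C*((C + obs_var*eye(N)) \ y);
disp(norm(mu - mu_alt))                    % the two forms agree up to round-off
```

The precision form is convenient in this setting because the inverse kernel matrix is computed once and reused in every later step.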

Coordinate Ascent Variational Gradient Matching

We minimize the KL-divergence in equation (10) by coordinate descent (where each step is analytically tractable), iterating between determining the proxy for the distribution over ODE parameters $\hat{q}(\boldsymbol\theta)$ and the proxies for the distributions over individual states $\hat{q}(\mathbf{x}_u)$.
state.proxy.mean = array2table([time.est',mu],'VariableNames',['time',symbols.state_string]);
bold_response.obs_old = bold_response.denoised_obs;
ode_param.proxy.mean = zeros(length(symbols.param),1);
for i = 1:opt_settings.coord_ascent_numb_iter

Proxy for Hemodynamic States

Determine the proxies for the hemodynamic states, starting with deoxyhemoglobin content, followed by blood volume, blood flow and finally vasosignalling. The information flow in the hemodynamic system is shown in its factor graph:
[Figure: factor graph of the hemodynamic system]
The model inversion in the hemodynamic factor graph occurs locally w.r.t. individual states. Given the BOLD signal change equation, we invert it analytically to determine the deoxyhemoglobin content $q$ (1). The newly inferred deoxyhemoglobin content influences the factor associated with the change in deoxyhemoglobin content, which we subsequently invert analytically to infer the blood volume $v$ (2). Thereafter, we infer the blood flow $f$ (3) by inverting the factors associated with the change in blood volume as well as vasosignalling, followed by inferring vasosignalling $s$ (4) by inverting the factors associated with blood flow induction and vasosignalling. Finally, the neuronal dynamics $n$ (5) are learned, in part, by inverting the factor associated with vasosignalling. The typical trajectories of each of the states are shown (red) together with their iterative approximation (grey lines) obtained by graphical DCM.
Damping is required since we invert only the factor for the BOLD signal change equation w.r.t. a monotonic function of the deoxyhemoglobin content $q$.
Undamped proxy:
state_proxy_undamped = proxy_for_deoxyhemoglobin_content(state.deoxyhemo,state.proxy.mean{:,symbols.state_string},...
bold_response.denoised_obs,symbols,A_plus_gamma_inv,opt_settings);
Damped proxy:
state.proxy.mean{:,{'q_1','q_3','q_2'}} = (1-opt_settings.damping) * state.proxy.mean{:,{'q_1','q_3','q_2'}} + ...
opt_settings.damping * state_proxy_undamped;
Damping is required since we invert only a subset of the ODEs w.r.t. a monotonic function of the blood volume $v$.
Undamped proxy:
state_proxy_undamped = proxy_for_blood_volume(state.vol,dC_times_invC,state.proxy.mean{:,symbols.state_string},...
ode_param.proxy.mean,symbols,A_plus_gamma_inv,opt_settings);
Damped proxy:
state.proxy.mean{:,{'v_1','v_3','v_2'}} = (1-opt_settings.damping) * state.proxy.mean{:,{'v_1','v_3','v_2'}} + ...
opt_settings.damping * state_proxy_undamped;
Damping is required since we invert only a subset of the ODEs w.r.t. a monotonic function of the blood flow $f$.
Undamped proxy:
state_proxy_undamped = proxy_for_blood_flow(state.flow,dC_times_invC,state.proxy.mean{:,symbols.state_string},...
ode_param.proxy.mean,symbols,A_plus_gamma_inv,opt_settings);
Damped proxy:
state.proxy.mean{:,{'f_1','f_3','f_2'}} = (1-opt_settings.damping) * state.proxy.mean{:,{'f_1','f_3','f_2'}} + ...
opt_settings.damping * state_proxy_undamped;
No damping is required because we invert all ODEs w.r.t. vasosignalling $s$.
state.proxy.mean{:,{'s_1','s_3','s_2'}} = proxy_for_vasosignalling(state.vaso,dC_times_invC,...
state.proxy.mean{:,symbols.state_string},ode_param.proxy.mean,symbols,A_plus_gamma_inv,opt_settings);
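
The damped updates used for deoxyhemoglobin content, blood volume and blood flow are convex combinations of the previous proxy mean and the newly computed undamped proxy. The sketch below isolates that update with a fixed stand-in for the undamped proxy; the damping value and the target are hypothetical (the script reads the damping factor from opt_settings.damping and recomputes the undamped proxy in every iteration).

```matlab
damping = 0.1;                      % hypothetical damping factor
target  = [1.0; -0.5; 2.0];         % stand-in for a fixed undamped proxy
state_mean = zeros(3,1);            % current proxy mean

for iter = 1:50
    state_proxy_undamped = target;  % recomputed from the ODEs in the script
    % Damped update: move only a fraction of the way towards the new proxy.
    state_mean = (1-damping)*state_mean + damping*state_proxy_undamped;
end

% With a fixed target the remaining gap shrinks by a factor (1-damping) per iteration.
disp(norm(state_mean - target))
```

A small damping factor slows convergence but guards against oscillation when the undamped proxy is based on only part of the available information, as is the case for the three damped hemodynamic states.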

Proxy for Neuronal States

Determine the proxies for the neuronal states. An example of the information flow in the neuronal part of the nonlinear forward modulation model (nonlin_fwd_mod) is shown in its factor graph:
[Figure: factor graph of the neuronal system for nonlinear forward modulation]
In the neuronal factor graph (for the nonlinear forward modulation), each individual state appears linearly in every factor of the neuronal model. We can therefore analytically invert every factor to determine the neuronal state. The typical trajectories of each of the states are shown (red) together with their iterative approximation (grey lines) obtained by variational gradient matching.
No damping is required because we invert all ODEs w.r.t. the neuronal populations $n$.
state.proxy.mean{:,{'n_1','n_3','n_2'}} = proxy_for_neuronal_populations(state.neuronal,...
state.proxy.mean{:,symbols.state_string},ode_param.proxy.mean',dC_times_invC,...
coupling_idx.states,symbols,A_plus_gamma_inv,opt_settings);
Keep initial value at zero:
state_idx = cellfun(@(x) ~strcmp(x(1),'u'),symbols.state_string);
state.proxy.mean{:,symbols.state_string(state_idx)} = bsxfun(@minus,state.proxy.mean{:,symbols.state_string(state_idx)},...
state.proxy.mean{1,symbols.state_string(state_idx)});

Proxy for ODE parameters

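Expanding the proxy distribution in equation (11) for $\boldsymbol\theta$ yields:
$\hat{q}(\boldsymbol\theta) \propto \exp\left( E_{Q_{-\boldsymbol\theta}} \ln \mathcal{N}\left( \boldsymbol\theta; (\mathbf{B}_{\boldsymbol\theta}^T \mathbf{B}_{\boldsymbol\theta})^{-1} \sum_k \mathbf{B}_{\boldsymbol\theta k}^T \left( {'\mathbf{C}_{\phi_k}} \mathbf{C}_{\phi_k}^{-1} \mathbf{x}_k - \mathbf{b}_{\boldsymbol\theta k} \right), \; \mathbf{B}_{\boldsymbol\theta}^{+} (\mathbf{A} + \mathbf{I}\gamma) \mathbf{B}_{\boldsymbol\theta}^{+T} \right) \right)$,
where we substitute $p(\boldsymbol\theta \mid \mathbf{X},\boldsymbol\phi,\gamma)$ with its density given in equation (6).
No damping is required because we invert all ODEs w.r.t. the neuronal couplings.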
if i>200 || i==opt_settings.coord_ascent_numb_iter
[ode_param.proxy.mean,ode_param.proxy.inv_cov] = proxy_for_ode_parameters(...
state.proxy.mean{:,symbols.state_string},dC_times_invC,ode_param.lin_comb,...
symbols,A_plus_gamma_inv,opt_settings);
end

Intercept due to Confounding Effects

The BOLD response is given by:
,
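where $\mathbf{y}$ are the BOLD observations, $\boldsymbol\lambda(\mathbf{q},\mathbf{v})$ is the BOLD signal change equation and the matrix $\mathbf{X}_0$ (simulation.X0 in the code) is given. The intercept is determined by a least-squares estimator, i.e. the residual between the observations and the BOLD signal change is projected onto the column space of $\mathbf{X}_0$:
$\mathbf{X}_0 (\mathbf{X}_0^T \mathbf{X}_0)^{-1} \mathbf{X}_0^T \left( \mathbf{y} - \boldsymbol\lambda(\mathbf{q},\mathbf{v}) \right)$.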
bold_signal_change = bold_signal_change_eqn(state.proxy.mean{:,{'v_1','v_3','v_2'}},state.proxy.mean{:,{'q_1','q_3','q_2'}});
intercept = simulation.X0 * (simulation.X0' * simulation.X0)^(-1) * simulation.X0' * (bold_response.obs_old-bold_signal_change);
bold_response.denoised_obs = bold_response.obs_old + intercept;
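
The least-squares projection used for the intercept can be checked in isolation with a small sketch. The confound matrix (a constant and a linear drift) and the residual values are hypothetical; the script takes the matrix from simulation.X0 and applies the projection to all regions at once.

```matlab
N  = 6;
X0 = [ones(N,1), (1:N)'];                    % hypothetical confounds: constant and linear drift
residual = [0.9; 1.2; 1.4; 1.9; 2.1; 2.6];   % stand-in for observations minus BOLD signal change

% Least-squares coefficients via backslash (avoids forming an explicit inverse).
beta      = (X0'*X0) \ (X0'*residual);
intercept = X0*beta;

% The explicit-inverse form used in the script gives the same result here.
intercept_inv = X0 * (X0'*X0)^(-1) * X0' * residual;
disp(norm(intercept - intercept_inv))

% The part of the residual that is not explained by X0 is orthogonal to its columns.
disp(norm(X0'*(residual - intercept)))
```

Using backslash is a design choice rather than a requirement: both forms agree when X0'*X0 is well conditioned, but the explicit inverse degrades faster when the columns of X0 are nearly collinear.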
Intermediate Results:
if i==1 || ~mod(i,2)
plot_results(fig_handle,state.proxy,simulation,ode_param.proxy.mean,plot_handle,symbols,plot_settings,'not_final');
end
end

Numerical Integration with Estimated ODE Parameters

Check whether the ODEs, integrated numerically with the estimated parameters, fit the BOLD responses well. The resulting curves are shown in black.
simulation2 = simulation_old; simulation2.ode_param = ode_param.proxy.mean';
[simulation2,obs_to_state_relation] = simulate_state_dynamics_dcm(simulation2,symbols,ode,time,...
plot_settings,state.ext_input,'no plot');
state.proxy.num_int = simulation2.state;
Final Results
plot_results(fig_handle,state.proxy,simulation,ode_param.proxy.mean,plot_handle,symbols,...
plot_settings,'final',simulation2.bold_response_true,simulation.odes,candidate_odes);
Warning: Matrix is singular to working precision.
Warning: Matrix is singular to working precision.

Time Taken

disp(['time taken: ' num2str(toc) ' seconds'])
time taken: 110.5325 seconds

References

Gorbach, N.S. Validation and Inference of Structural Connectivity and Neural Dynamics with MRI data. 2018. ETH Zürich Doctoral Thesis. https://www.research-collection.ethz.ch/handle/20.500.11850/261734.
Gorbach, N.S., Bauer, S. and Buhmann, J.M., Scalable Variational Inference for Dynamical Systems. 2017a. Neural Information Processing Systems (NIPS).
Bauer, S., Gorbach, N.S. and Buhmann, J.M., Efficient and Flexible Inference for Stochastic Differential Equations. 2017b. Neural Information Processing Systems (NIPS).
Wenk, P., Gotovos, A., Bauer, S., Gorbach, N.S., Krause, A. and Buhmann, J.M., Fast Gaussian Process Based Gradient Matching for Parameters Identification in Systems of Nonlinear ODEs. 2018. In submission to Conference on Uncertainty in Artificial Intelligence (UAI).
Calderhead, B., Girolami, M. and Lawrence, N.D., Accelerating Bayesian inference over nonlinear differential equations with Gaussian processes. 2008. In Advances in Neural Information Processing Systems (NIPS) 21.
The authors in bold font have contributed equally to their respective papers.